#!/usr/bin/env python
# coding=utf-8
'''
Author: JiangJi
Email: johnjim0816@gmail.com
Date: 2021-12-22 10:40:05
LastEditor: JiangJi
LastEditTime: 2021-12-22 10:43:55
Description: TD3 (Twin Delayed DDPG) agent: Actor network, twin Critic networks, and the TD3 update loop.
'''
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from TD3_memory import ReplayBuffer

class Actor(nn.Module):
	
	def __init__(self, input_dim, output_dim, max_action):
		'''Policy network: maps a state to a deterministic action.

		Args:
			input_dim (int): input dimension, here equal to n_states (size of the state vector)
			output_dim (int): output dimension, here equal to n_actions (size of the action vector)
			max_action (float): maximum action value (continuous actions are bounded)
		'''
		super(Actor, self).__init__()

		self.l1 = nn.Linear(input_dim, 256)
		self.l2 = nn.Linear(256, 256)
		self.l3 = nn.Linear(256, output_dim)
		self.max_action = max_action
	
	def forward(self, state):
		
		a = F.relu(self.l1(state))						# linear layer + ReLU
		a = F.relu(self.l2(a))							# linear layer + ReLU
		return self.max_action * torch.tanh(self.l3(a)) # linear layer + tanh (range [-1, 1]), scaled by max_action


class Critic(nn.Module):
	def __init__(self, input_dim, output_dim):
		'''Twin Q-networks: map a (state, action) pair to two scalar Q-value estimates.

		Args:
			input_dim (int): n_states (size of the state vector)
			output_dim (int): n_actions (size of the action vector)
		'''
		super(Critic, self).__init__()

		# Q1 architecture
		self.l1 = nn.Linear(input_dim + output_dim, 256)
		self.l2 = nn.Linear(256, 256)
		self.l3 = nn.Linear(256, 1)	# the Q-function outputs a scalar value

		# Q2 architecture
		self.l4 = nn.Linear(input_dim + output_dim, 256)
		self.l5 = nn.Linear(256, 256)
		self.l6 = nn.Linear(256, 1)	# the Q-function outputs a scalar value


	def forward(self, state, action):
		sa = torch.cat([state, action], 1)

		q1 = F.relu(self.l1(sa)) # linear layer + ReLU
		q1 = F.relu(self.l2(q1)) # linear layer + ReLU
		q1 = self.l3(q1)		 # linear layer (scalar Q-value)

		q2 = F.relu(self.l4(sa))
		q2 = F.relu(self.l5(q2))
		q2 = self.l6(q2)
		return q1, q2


	def Q1(self, state, action):
		'''Return only the Q1 estimate (used for the actor loss).'''
		sa = torch.cat([state, action], 1)

		q1 = F.relu(self.l1(sa))
		q1 = F.relu(self.l2(q1))
		q1 = self.l3(q1)
		return q1


class TD3(object):
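	'''TD3 (Twin Delayed DDPG) agent. On top of DDPG it adds, as implemented in update():
	(1) clipped double-Q learning: two critics, and a target that uses min(Q1, Q2) to curb overestimation;
	(2) target policy smoothing: clipped Gaussian noise added to the target action before it is evaluated;
	(3) delayed updates: the actor and both target networks are updated only every policy_freq critic updates.
	'''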
	def __init__(
		self,
		input_dim,   # state dimension (n_states)
		output_dim,	 # action dimension (n_actions)
		max_action,  # upper bound of the action range
		cfg,         # hyperparameter config
	):
		self.max_action = max_action
		self.gamma = cfg.gamma  				# discount factor
		self.lr = cfg.lr						# learning rate (also reused below as the target soft-update coefficient)
		self.policy_noise = cfg.policy_noise	# std of the noise added to target actions (target policy smoothing)
		self.noise_clip = cfg.noise_clip 		# clipping range of that smoothing noise
		self.policy_freq = cfg.policy_freq		# delayed policy update frequency (typically 2)
		self.batch_size = cfg.batch_size 		# mini-batch size sampled from the replay buffer
		self.device = cfg.device
		self.total_it = 0						# counter of update steps performed

		self.actor = Actor(input_dim, output_dim, max_action).to(self.device)
		self.actor_target = copy.deepcopy(self.actor)  # target actor network
		self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=3e-4) # optimizer for the policy network

		self.critic = Critic(input_dim, output_dim).to(self.device)
		self.critic_target = copy.deepcopy(self.critic) # target critic network
		self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=3e-4) # optimizer for the Q-networks

		self.memory = ReplayBuffer(input_dim, output_dim)

	def choose_action(self, state):
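		'''Select a deterministic action from the current policy for a single state
		(no exploration noise is added here).'''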
		state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)

		# flatten the (1, n_actions) actor output into a 1-D numpy action vector
		return self.actor(state).cpu().data.numpy().flatten()

	def update(self):
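		'''Run one TD3 training step: sample a batch, update both critics, and every
		policy_freq steps update the actor and soft-update the target networks.'''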
		self.total_it += 1

		# Sample replay buffer 
		state, action, next_state, reward, not_done = self.memory.sample(self.batch_size)

		# no gradients are needed while building the training target
		with torch.no_grad():
			# Select action according to policy and add clipped noise
			noise = (
				torch.randn_like(action) * self.policy_noise
			).clamp(-self.noise_clip, self.noise_clip)
			
			next_action = (
				self.actor_target(next_state) + noise
			).clamp(-self.max_action, self.max_action)

			# Compute the target Q value
			target_Q1, target_Q2 = self.critic_target(next_state, next_action)
			# clipped double-Q: use the minimum of the two target estimates to curb overestimation
			target_Q = torch.min(target_Q1, target_Q2)
			# bootstrap the discounted target only for non-terminal transitions
			target_Q = reward + not_done * self.gamma * target_Q

		# Get current Q estimates
		current_Q1, current_Q2 = self.critic(state, action)

		# Compute critic loss
		critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)

		# Optimize the critic
		self.critic_optimizer.zero_grad()
		critic_loss.backward()
		self.critic_optimizer.step()

		# Delayed policy updates: the actor and the target networks
		# are updated less often than the critic
		if self.total_it % self.policy_freq == 0:

			# Compute actor loss: maximize Q1 under the current policy (hence the minus sign)
			actor_loss = -self.critic.Q1(state, self.actor(state)).mean()
			
			# Optimize the actor 
			self.actor_optimizer.zero_grad()
			actor_loss.backward()
			self.actor_optimizer.step()

			# Soft-update the frozen target networks (Polyak averaging);
			# note that self.lr is reused here as the soft-update coefficient (the role usually named tau)
			for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
				target_param.data.copy_(self.lr * param.data + (1 - self.lr) * target_param.data)

			for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
				target_param.data.copy_(self.lr * param.data + (1 - self.lr) * target_param.data)


	def save(self, path):
		torch.save(self.critic.state_dict(), path + "td3_critic")
		torch.save(self.critic_optimizer.state_dict(), path + "td3_critic_optimizer")
		
		torch.save(self.actor.state_dict(), path + "td3_actor")
		torch.save(self.actor_optimizer.state_dict(), path + "td3_actor_optimizer")


	def load(self, path):
		self.critic.load_state_dict(torch.load(path + "td3_critic"))
		self.critic_optimizer.load_state_dict(torch.load(path + "td3_critic_optimizer"))
		self.critic_target = copy.deepcopy(self.critic)

		self.actor.load_state_dict(torch.load(path + "td3_actor"))
		self.actor_optimizer.load_state_dict(torch.load(path + "td3_actor_optimizer"))
		self.actor_target = copy.deepcopy(self.actor)
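

# Minimal usage sketch (not part of the original module): shows how the agent is constructed and
# how an action is selected. The SimpleNamespace config and the example hyperparameter values below
# are illustrative assumptions; only the cfg fields read in __init__ (gamma, lr, policy_noise,
# noise_clip, policy_freq, batch_size, device) are actually required by this class.
if __name__ == "__main__":
	from types import SimpleNamespace

	n_states, n_actions, max_action = 3, 1, 1.0  # Pendulum-like sizes; values are illustrative
	cfg = SimpleNamespace(
		gamma=0.99,
		lr=0.005,            # note: also reused as the soft-update coefficient in update()
		policy_noise=0.2,
		noise_clip=0.5,
		policy_freq=2,
		batch_size=256,
		device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
	)
	agent = TD3(n_states, n_actions, max_action, cfg)

	state = np.random.randn(n_states).astype(np.float32)  # stand-in for an environment observation
	action = agent.choose_action(state)                   # 1-D numpy array in [-max_action, max_action]
	print("sampled action:", action)
	# During training, transitions are first stored in agent.memory (see TD3_memory.ReplayBuffer),
	# and agent.update() is called each step once the buffer holds at least cfg.batch_size samples.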
		
